﻿using System;
using System.Collections.Generic;
using System.Threading;
using HAPI;
using HuginApiAddonsCS.Extensions;
using HuginApiAddonsCS.Policy;

namespace HuginApiAddonsCS.I_DID
{
    /// <summary>
    /// Loads agent j policy trees and enters them as models into an I-DID domain.
    /// </summary>
    public class IdidEnterModels
    {
        /// <summary>
        /// Extracts the policy trees from DID's before calling EnterPolicyTreesIntoIdid to add them to the input
        /// I-DID for saving. One thread per DID file is used to save time when loading many nets or large nets.
        /// The trees are handed on in the same order as <paramref name="didNetFiles"/>, so the model numbering
        /// assigned by EnterPolicyTreesIntoIdid does not depend on which thread finishes first.
        /// </summary>
        /// <param name="ididNetFile">input I-DID net file</param>
        /// <param name="outNet">output I-DID net file</param>
        /// <param name="didNetFiles">list of DID net files</param>
        /// <param name="ajPrefix">agent j action prefix</param>
        /// <param name="modPrefix">agent j model prefix</param>
        /// <param name="ojPrefix">agent j observation prefix</param>
        /// <exception cref="ArgumentNullException">thrown when <paramref name="didNetFiles"/> is null</exception>
        public void EnterNetModelsIntoIdid(string ididNetFile, string outNet, List<string> didNetFiles, string ajPrefix, string modPrefix, string ojPrefix)
        {
            if (didNetFiles == null)
            {
                throw new ArgumentNullException("didNetFiles");
            }

            ParseListener pl = new DefaultClassParseListener();
            Domain inputIdid = new Domain(ididNetFile, pl);

            int total = didNetFiles.Count;
            PolicyNode[] loadedTrees = new PolicyNode[total];
            Thread[] workers = new Thread[total];

            for (int i = 0; i < total; i++)
            {
                // Per-iteration copies: each delegate must capture its own slot and file,
                // never a variable that is shared between loop iterations.
                int slot = i;
                string file = didNetFiles[i];

                workers[i] = new Thread(delegate()
                {
                    // Every worker writes only to its own array slot, so no lock is required.
                    loadedTrees[slot] = LoadIdidDomainPolicy(file, ajPrefix, ojPrefix);
                });
                workers[i].Start();
            }

            // Block on each worker in turn rather than spinning on a shared counter.
            // The figure printed is a lower bound: later workers may already be finished.
            Console.Write("\rLoad policies progress {0}%   ", PercentComplete(0, total));
            for (int i = 0; i < total; i++)
            {
                workers[i].Join();
                Console.Write("\rLoad policies progress {0}%   ", PercentComplete(i + 1, total));
            }

            List<PolicyNode> policyNodes = new List<PolicyNode>(loadedTrees);

            Console.WriteLine("\nEntering policies into I-DID");
            inputIdid = EnterPolicyTreesIntoIdid(inputIdid, policyNodes, ajPrefix, modPrefix, ojPrefix);
            Console.WriteLine("\nSaving I-DID net file");
            inputIdid.SaveAsNet(outNet);
        }

        /// <summary>
        /// Loads policy trees from csv policy files before adding them to the I-DID.
        /// Not implemented yet.
        /// </summary>
        /// <param name="ididNetFile">input I-DID net file</param>
        /// <param name="outNet">output I-DID net file</param>
        /// <param name="didNetFiles">list of DID net files</param>
        /// <param name="ajPrefix">agent j action prefix</param>
        /// <param name="modPrefix">agent j model prefix</param>
        /// <param name="ojPrefix">agent j observation prefix</param>
        /// <exception cref="NotImplementedException">always thrown</exception>
        public void EnterCsvModelsIntoIdid(string ididNetFile, string outNet, List<string> didNetFiles, string ajPrefix, string modPrefix, string ojPrefix)
        {
            //ToDo
            throw new NotImplementedException();
        }

        /// <summary>
        /// Loads policies from xml trees before adding them to the I-DID.
        /// Not implemented yet.
        /// </summary>
        /// <param name="ididNetFile">input I-DID net file</param>
        /// <param name="outNet">output I-DID net file</param>
        /// <param name="didNetFiles">list of DID net files</param>
        /// <param name="ajPrefix">agent j action prefix</param>
        /// <param name="modPrefix">agent j model prefix</param>
        /// <param name="ojPrefix">agent j observation prefix</param>
        /// <exception cref="NotImplementedException">always thrown</exception>
        public void EnterXmlModelsIntoIdid(string ididNetFile, string outNet, List<string> didNetFiles, string ajPrefix, string modPrefix, string ojPrefix)
        {
            //ToDo
            throw new NotImplementedException();
        }

        /// <summary>
        /// Adds each policy tree to the I-DID domain, loops through each tree 1 at a time
        /// and adds its tree by getting nodes from the tree at each time step.
        /// The first tree replaces the existing states of each model node; every later tree
        /// appends its states after the ones already present. State labels are the model name
        /// ("Mod" followed by the tree's index in <paramref name="inputPolicies"/>) followed by the node history.
        /// </summary>
        /// <param name="inputDID">input I-DID</param>
        /// <param name="inputPolicies">input trees</param>
        /// <param name="ajPrefix">agent j action prefix</param>
        /// <param name="modPrefix">model prefix</param>
        /// <param name="ojPrefix">agent j observation prefix</param>
        /// <returns>filled in domain (the same instance as <paramref name="inputDID"/>)</returns>
        public Domain EnterPolicyTreesIntoIdid(Domain inputDID, List<PolicyNode> inputPolicies, string ajPrefix, string modPrefix, string ojPrefix)
        {
            bool first = true;
            int firstStepDelta = inputDID.GetFirstTimeStep();
            int modelCount = 0;

            foreach (PolicyNode tree in inputPolicies)
            {
                string modelName = "Mod" + modelCount;
                //for each step add node histories to domain - need to prefix state with model number
                for (int i = 0; i < tree.GetTreeLength(); i++)
                {
                    List<PolicyNode> levelNodes = tree.GetNodesForLevel(i + 1);
                    DiscreteNode modelNode = (DiscreteNode)inputDID.GetNodeByName(modPrefix + (i + firstStepDelta));

                    //Restart state count for the first model, append after the existing states otherwise
                    uint offset = first ? 0u : modelNode.GetNumberOfStates();
                    modelNode.SetNumberOfStates(offset + (uint)levelNodes.Count);

                    for (int j = 0; j < levelNodes.Count; j++)
                    {
                        string history = levelNodes[j].GetNodeHistory(null);
                        modelNode.SetStateLabel(offset + (uint)j, modelName + history);
                    }
                }
                first = false;

                modelCount++;
            }

            for (int i = inputDID.GetFirstTimeStep(); i <= inputDID.GetLastTimeStep(); i++)
            {
                //Need to get parents and states for model number
                //Set aj policy/probabilities for history
                DiscreteNode timeStepActionNode = (DiscreteNode)inputDID.GetNodeByName(ajPrefix + i);
                DiscreteNode timeStepModelNode = (DiscreteNode)inputDID.GetNodeByName(modPrefix + i);

                timeStepActionNode.SetJActionNodeTablesFromTree(inputPolicies);

                // The model node table is rebuilt from the previous step's model node and this step's
                // observation node. NOTE(review): at the first time step the node named
                // modPrefix + (i - 1) may not exist - confirm the extension method copes with that.
                timeStepModelNode.ResetModelNodeTableFromParents((DiscreteNode)inputDID.GetNodeByName(modPrefix + (i - 1)),
                                                                    (DiscreteNode)inputDID.GetNodeByName(ojPrefix + i));
            }

            return inputDID;
        }

        /// <summary>
        /// Enters policy trees into a domain file: loads the domain, fills each tree using the action
        /// labels of the node named <paramref name="ajPrefix"/> + 1, enters the trees and saves the result.
        /// </summary>
        /// <param name="inputDomain">input domain file</param>
        /// <param name="outputDomain">output domain file</param>
        /// <param name="inputPolicies">input policy trees</param>
        /// <param name="ajPrefix">agent j action prefix</param>
        /// <param name="modPrefix">model prefix</param>
        /// <param name="ojPrefix">agent j observation prefix</param>
        /// <returns>modified domain</returns>
        public Domain EnterPolicyTreesIntoIdid(string inputDomain, string outputDomain, List<PolicyNode> inputPolicies, string ajPrefix,
            string modPrefix, string ojPrefix)
        {
            ParseListener pl = new DefaultClassParseListener();
            Domain inputIdid = new Domain(inputDomain, pl);

            //Calls random fill on each policy tree to fill in any potential gaps
            //so the tree can be easily entered into output I-DID
            List<string> potentialActions = new List<string>();
            DiscreteNode actionNode = (DiscreteNode)inputIdid.GetNodeByName(ajPrefix + 1);

            for (uint i = 0; i < actionNode.GetNumberOfStates(); i++)
            {
                potentialActions.Add(actionNode.GetStateLabel(i));
            }

            foreach (PolicyNode root in inputPolicies)
            {
                root.RandomFill(potentialActions);
            }

            inputIdid = EnterPolicyTreesIntoIdid(inputIdid, inputPolicies, ajPrefix, modPrefix, ojPrefix);

            inputIdid.SaveAsNet(outputDomain);

            return inputIdid;
        }

        /// <summary>
        /// Loads the policy for a given input domain file and prefixes
        /// </summary>
        /// <param name="domain">input domain file</param>
        /// <param name="ajPrefix">action prefix</param>
        /// <param name="ojPrefix">observation prefix</param>
        /// <returns>policy tree</returns>
        private PolicyNode LoadIdidDomainPolicy(string domain, string ajPrefix, string ojPrefix)
        {
            ParseListener pl = new DefaultClassParseListener();
            Domain inputDid = new Domain(domain, pl);
            return inputDid.GetPolicyTree(ajPrefix, ojPrefix);
        }

        /// <summary>
        /// Converts a completed/total pair into a whole percentage.
        /// </summary>
        /// <param name="done">number of finished items</param>
        /// <param name="total">total number of items</param>
        /// <returns>percentage in whole numbers; 100 when there is nothing to do</returns>
        private static int PercentComplete(int done, int total)
        {
            if (total <= 0)
            {
                // Nothing to load counts as complete and avoids dividing by zero.
                return 100;
            }

            // Multiply before dividing so integer division does not truncate to zero.
            return (int)((long)done * 100 / total);
        }
    }
}
